import dgl
import dgl.function as fn
from dgl.nn import SAGEConv

import torch
import torch.nn as nn
import torch.nn.functional as F
import copy

import unittest
import time

from sampler import LocalNeighborSampler, GlobalNeighborSampler
from data.load_dataset import load_ogb, load_partitioned_graphs, load_complete_graphs

class TestCustomSAGEConv(unittest.TestCase):
    """Unit test comparing ``CustomSAGEConv`` with DGL's ``SAGEConv``."""

    def test_forward(self):
        """Forward a local/remote block pair and compare with a mean-aggregator ``SAGEConv``."""
        # Toy directed graph over nodes 0..6
        g = dgl.graph([(0, 2), (0, 3), (1, 2), (1, 4), (2, 5), (3, 4), (3, 6), (4, 5), (4, 6), (5, 6)])

        # Remote side: seeds 3, 4 and 5, sampled with a fanout of 5 and turned into a block
        remote_nodes = torch.tensor([3, 4, 5])
        remote_block = dgl.to_block(
            dgl.sampling.sample_neighbors(g, remote_nodes, fanout=5),
            remote_nodes,
        )

        # Local side: seeds 2 and 6, handled the same way
        local_nodes = torch.tensor([2, 6])
        local_block = dgl.to_block(
            dgl.sampling.sample_neighbors(g, local_nodes, fanout=5),
            local_nodes,
        )

        # Random node features, gathered per block through the source node IDs
        in_feats, out_feats = 16, 32
        g.ndata['feat'] = torch.randn(g.num_nodes(), in_feats)
        all_feat = g.ndata['feat']
        remote_feat = all_feat[remote_block.srcdata[dgl.NID]]
        local_feat = all_feat[local_block.srcdata[dgl.NID]]

        # Layer under test
        activation = F.relu
        conv = CustomSAGEConv(in_feats, out_feats, activation)
        print("CustomSAGEConv:", conv)

        h = conv(local_block, remote_block, local_feat, remote_feat)
        print(h)

        # One output row per destination node: local rows first, then remote rows
        n_out_rows = local_block.num_dst_nodes() + remote_block.num_dst_nodes()
        self.assertEqual(h.shape, (n_out_rows, out_feats))

        # ReLU output can never be negative
        self.assertTrue(torch.all(h >= 0))

        # Reference layer sharing the weights of the layer under test
        official_conv = SAGEConv(in_feats, out_feats, activation=activation, aggregator_type='mean')
        official_conv.fc_neigh.weight.data = conv.fc_neigh.weight.data
        official_conv.fc_self.weight.data = conv.fc_self.weight.data
        official_conv.fc_self.bias.data = conv.fc_self.bias.data

        # Full-graph reference output, reordered to match ``h`` (local rows, then remote rows)
        official_h = official_conv(g, g.ndata['feat'])
        dst_order = torch.cat(
            [local_block.dstdata[dgl.NID], remote_block.dstdata[dgl.NID]], dim=0
        )
        official_h_combined = official_h[dst_order]
        print("Official h:", official_h_combined)

        self.assertTrue(torch.allclose(h, official_h_combined, rtol=1e-4, atol=1e-4))

def compute_acc(pred, labels):
    """
    Compute the accuracy of prediction given the labels.

    ``pred`` holds one row of per-class scores per sample; the predicted class
    is the argmax over dimension 1. ``labels`` holds the ground-truth class
    indices and is cast to ``long`` before the comparison. Labels carrying one
    extra trailing singleton dimension (e.g. shape ``(N, 1)``) are squeezed so
    that they are compared element-wise instead of being broadcast.

    Returns a scalar tensor: the number of matches divided by ``len(pred)``.
    """
    predicted = torch.argmax(pred, dim=1)
    labels = labels.long()
    # Without this, (N,) == (N, 1) broadcasts to an (N, N) comparison and the
    # "accuracy" can exceed 1.
    if labels.dim() == predicted.dim() + 1 and labels.shape[-1] == 1:
        labels = labels.squeeze(-1)
    return (predicted == labels).float().sum() / len(pred)

def evaluate(model, g, valid_sampler, device, valid_nid, args):
    """
    Evaluate the model on a random mini-batch of the validation set ``valid_nid``.

    At most ``args.valid_batch_size`` node IDs are drawn from ``valid_nid`` in
    random order, sampled into (local, remote) block pairs with
    ``valid_sampler.sample(g, seeds)``, moved to ``device`` and passed through
    ``model`` without gradient tracking.

    The model is put in eval mode for the forward pass and is always switched
    back to train mode afterwards, even if the forward pass raises.

    Returns the accuracy computed by ``compute_acc`` on that mini-batch.
    """
    # Random subset of the validation nodes. Indexing with an index tensor
    # already yields a new tensor, so no explicit copy of ``valid_nid`` is needed.
    shuffle_index = torch.randperm(len(valid_nid))[:args.valid_batch_size]
    seeds = valid_nid[shuffle_index]

    blocks, transform_indices = valid_sampler.sample(g, seeds)
    # Moving to the device
    for i, (local_block, remote_block) in enumerate(blocks):
        blocks[i] = (local_block.to(device), remote_block.to(device))

    # Extract features (input of the first layer)
    first_local_block, first_remote_block = blocks[0]
    local_feat_src = first_local_block.srcdata['features'].to(device)
    remote_feat_src = first_remote_block.srcdata['features'].to(device)

    # Extract labels (output of the last layer)
    last_local_block, last_remote_block = blocks[-1]
    assert last_remote_block.number_of_edges() == 0 # The last remote block should be empty
    labels = last_local_block.dstdata['labels'].to(device)

    model.eval()
    try:
        with torch.no_grad():
            batch_pred = model(blocks, local_feat_src, remote_feat_src, transform_indices)
    finally:
        # Never leave the caller's model stuck in eval mode.
        model.train()
    return compute_acc(batch_pred, labels)

class CustomSAGEConv(nn.Module):
    """
    Mean-aggregation GraphSAGE layer evaluated over a (local, remote) pair of blocks.

    Each block is aggregated on its own; the per-block results are then
    concatenated (local rows first, remote rows second) and combined as
    ``fc_self(self features) + neighbor term``, followed by the optional
    ``activation`` and then the optional ``norm``.
    """

    def __init__(self, in_feats, out_feats, activation=None, norm=None, bias=True):
        super().__init__()
        self.in_feats = in_feats
        self.out_feats = out_feats

        # Apply the neighbor projection before message passing when it shrinks
        # the feature size, otherwise after the aggregation.
        self.lin_before_mp = in_feats > out_feats

        self.fc_neigh = nn.Linear(in_feats, out_feats, bias=False)
        self.fc_self = nn.Linear(in_feats, out_feats, bias=bias)

        self.activation = activation
        self.norm = norm

    def _aggregate(self, block, feat_src):
        """Return ``(projected mean of neighbor features, raw self features)`` for one block."""
        with block.local_scope():
            if self.lin_before_mp:
                block.srcdata['h'] = self.fc_neigh(feat_src)
            else:
                block.srcdata['h'] = feat_src
            # Copy 'h' from the source nodes into message 'm', then average per destination node.
            block.update_all(fn.copy_u('h', 'm'), fn.mean('m', 'neigh'))
            neigh = block.dstdata['neigh']
            if not self.lin_before_mp:
                neigh = self.fc_neigh(neigh)
        # Assumes the first ``number_of_dst_nodes()`` source rows belong to the
        # destination nodes (the slicing relies on that ordering).
        self_feat = feat_src[:block.number_of_dst_nodes()]
        return neigh, self_feat

    def forward(self, local_block, remote_block, local_feat_src, remote_feat_src):
        """
        Compute the layer output for the destination nodes of both blocks.

        Returns a tensor whose rows are the local destination nodes followed
        by the remote destination nodes.
        """
        # This should be calculated on the client
        local_neigh, local_self = self._aggregate(local_block, local_feat_src)
        # This should be calculated on the server technically
        remote_neigh, remote_self = self._aggregate(remote_block, remote_feat_src)

        # The following should be calculated on the client
        h_neigh = torch.cat([local_neigh, remote_neigh], dim=0)
        h_self = self.fc_self(torch.cat([local_self, remote_self], dim=0))
        # Both terms must line up row by row before they are summed
        assert h_neigh.shape == h_self.shape
        h = h_self + h_neigh

        if self.activation is not None:
            h = self.activation(h)
        if self.norm is not None:
            h = self.norm(h)

        return h

class ProfileCustomSAGEConv(nn.Module):
    """
    Profiling variant of the split (local, remote) mean-aggregation GraphSAGE layer.

    Computes the same output as the non-profiling layer and additionally
    returns the wall-clock time, in seconds, spent on the remote-block
    ("server") aggregation.
    """

    def __init__(self, in_feats, out_feats, activation=None, norm=None, bias=True):
        super(ProfileCustomSAGEConv, self).__init__()
        self.in_feats = in_feats
        self.out_feats = out_feats

        # Determine whether to apply linear transformation before message passing
        self.lin_before_mp = in_feats > out_feats

        self.fc_neigh = nn.Linear(in_feats, out_feats, bias=False)
        self.fc_self = nn.Linear(in_feats, out_feats, bias=bias)

        self.activation = activation
        self.norm = norm

    def _aggregate_neighbors(self, block, feat_src):
        """Return ``(projected mean of neighbor features, raw self features)`` for one block."""
        with block.local_scope():
            block.srcdata['h'] = (
                self.fc_neigh(feat_src) if self.lin_before_mp else feat_src
            )
            # Create 'm' on dst nodes by copying 'h' from src nodes, then perform mean aggregation.
            block.update_all(fn.copy_u('h', 'm'), fn.mean('m', 'neigh'))
            neigh = block.dstdata['neigh']
            # Assumes the first ``number_of_dst_nodes()`` source rows belong to
            # the destination nodes (the slicing relies on that ordering).
            self_feat = feat_src[:block.number_of_dst_nodes()]
            if not self.lin_before_mp:
                neigh = self.fc_neigh(neigh)
        return neigh, self_feat

    def forward(self, local_block, remote_block, local_feat_src, remote_feat_src):
        """
        Compute the layer output and time the remote-block aggregation.

        Returns ``(h, server_compu_time)``: ``h`` has the local destination
        rows followed by the remote destination rows; ``server_compu_time`` is
        the elapsed time in seconds of the remote aggregation.
        """
        # This should be calculated on the client
        local_neigh, local_self = self._aggregate_neighbors(local_block, local_feat_src)

        # Only synchronize when the data actually lives on a CUDA device, so
        # the layer also works on CPU-only machines.
        on_cuda = remote_feat_src.is_cuda
        if on_cuda:
            # Flush pending kernels first so that the local-block work queued
            # above is not attributed to the server.
            torch.cuda.synchronize(remote_feat_src.device)
        start_time = time.perf_counter()
        # This should be calculated on the server technically
        remote_neigh, remote_self = self._aggregate_neighbors(remote_block, remote_feat_src)
        if on_cuda:
            torch.cuda.synchronize(remote_feat_src.device)    # Wait for the computation to be finished
        server_compu_time = time.perf_counter() - start_time

        # The following should be calculated on the client
        # Combine the results
        h_neigh = torch.cat([local_neigh, remote_neigh], dim=0)
        h_self = torch.cat([local_self, remote_self], dim=0)
        h_self = self.fc_self(h_self)
        # Both terms must have the same shape before they are summed
        assert h_neigh.shape == h_self.shape
        h = h_self + h_neigh

        # Apply activation
        if self.activation is not None:
            h = self.activation(h)
        # Normalize
        if self.norm is not None:
            h = self.norm(h)

        return h, server_compu_time # Return the computation time for the server
        
class ClientServerGraphSage(nn.Module):
    """
    GraphSAGE model built from ``CustomSAGEConv`` layers, one per (local, remote) block pair.

    With ``n_layers > 1`` the stack maps ``in_feats -> n_hidden -> ... -> n_classes``;
    otherwise it is a single ``in_feats -> n_classes`` layer. ``activation`` and
    ``dropout`` are applied after every layer except the last one.
    """

    def __init__(self, in_feats, n_hidden, n_classes, n_layers,
                 activation, dropout):
        super().__init__()
        self.n_layers = n_layers
        self.n_hidden = n_hidden
        self.n_classes = n_classes

        # Feature sizes at the layer boundaries, from the input to the output
        if n_layers > 1:
            sizes = [in_feats, *([n_hidden] * (n_layers - 1)), n_classes]
        else:
            sizes = [in_feats, n_classes]

        self.layers = nn.ModuleList()
        for size_in, size_out in zip(sizes[:-1], sizes[1:]):
            self.layers.append(CustomSAGEConv(size_in, size_out))

        self.dropout = nn.Dropout(dropout)
        self.activation = activation

    def forward(self, blocks, local_feat_src, remote_feat_src, transform_indices):
        """
        Run all layers over ``blocks`` and return the output of the last executed layer.

        ``blocks[l]`` is the ``(local_block, remote_block)`` pair of layer ``l``.
        ``transform_indices[l]`` is a ``(local_indices, remote_indices)`` pair used
        to split the output of layer ``l`` into the inputs of layer ``l + 1``.
        """
        h_local, h_remote = local_feat_src, remote_feat_src
        last_layer = len(self.layers) - 1

        for depth, (layer, (local_block, remote_block)) in enumerate(zip(self.layers, blocks)):
            # The inputs must provide one row per source node of each block
            assert h_local.shape[0] == local_block.number_of_src_nodes()
            assert h_remote.shape[0] == remote_block.number_of_src_nodes()
            h = layer(local_block, remote_block, h_local, h_remote)
            if depth == last_layer:
                continue

            h = self.dropout(self.activation(h))

            # We perform the transformation on the server
            local_indices, remote_indices = transform_indices[depth]
            h_local, h_remote = h[local_indices], h[remote_indices]

            assert h_local.shape[0] == len(local_indices)
            assert h_remote.shape[0] == len(remote_indices)
            assert h_local.shape[0] + h_remote.shape[0] == h.shape[0]
        return h
    
class ProfileClientServerGraphSage(nn.Module):
    """
    Profiling GraphSAGE model built from ``ProfileCustomSAGEConv`` layers.

    Has the same layer layout as a plain stack (``in_feats -> n_hidden -> ... ->
    n_classes`` when ``n_layers > 1``, else one ``in_feats -> n_classes`` layer)
    and also collects the server computation time reported by each layer.
    """

    def __init__(self, in_feats, n_hidden, n_classes, n_layers,
                 activation, dropout):
        super().__init__()
        self.n_layers = n_layers
        self.n_hidden = n_hidden
        self.n_classes = n_classes

        # (input size, output size) of every layer, in execution order
        if n_layers > 1:
            layer_shapes = [(in_feats, n_hidden)]
            layer_shapes += [(n_hidden, n_hidden) for _ in range(1, n_layers - 1)]
            layer_shapes.append((n_hidden, n_classes))
        else:
            layer_shapes = [(in_feats, n_classes)]

        self.layers = nn.ModuleList()
        for feats_in, feats_out in layer_shapes:
            self.layers.append(ProfileCustomSAGEConv(feats_in, feats_out))

        self.dropout = nn.Dropout(dropout)
        self.activation = activation

    def forward(self, blocks, local_feat_src, remote_feat_src, transform_indices):
        """
        Run all layers over ``blocks``.

        Returns ``(h, server_comp_times)``: the output of the last executed
        layer and the list of per-layer server computation times.
        """
        h_local = local_feat_src
        h_remote = remote_feat_src
        n_total = len(self.layers)

        server_comp_times = []
        for l, (layer, block_pair) in enumerate(zip(self.layers, blocks)):
            local_block, remote_block = block_pair
            # The inputs must provide one row per source node of each block
            assert h_local.shape[0] == local_block.number_of_src_nodes()
            assert h_remote.shape[0] == remote_block.number_of_src_nodes()

            h, elapsed = layer(local_block, remote_block, h_local, h_remote)
            server_comp_times.append(elapsed)

            is_last = l == n_total - 1
            if not is_last:
                h = self.activation(h)
                h = self.dropout(h)

                # We perform the transformation on the server
                local_indices, remote_indices = transform_indices[l]
                h_local = h[local_indices]
                h_remote = h[remote_indices]

                assert h_local.shape[0] == len(local_indices)
                assert h_remote.shape[0] == len(remote_indices)
                assert h_local.shape[0] + h_remote.shape[0] == h.shape[0]
        return h, server_comp_times

def _run_global_sampler_demo():
    """
    Smoke-test ``GlobalNeighborSampler`` together with ``CustomSAGEConv`` on ogbn-arxiv.

    Loads the complete graph and the first partition, samples blocks for the
    first 20 training nodes, runs one ``CustomSAGEConv`` layer on the first
    (local, global) block pair and checks that its output has one row per
    source node of the second block pair.
    """
    # Load complete graph
    complete_graph = load_complete_graphs('ogbn-arxiv')

    # Load partitioned graph
    path_to_partitioned_dataset = 'partitioned_dataset/ogbn-arxiv_metis_5.bin'
    partitioned_graphs = load_partitioned_graphs(path_to_partitioned_dataset)
    local_graph = partitioned_graphs[0]

    # Get all the training nodes.
    # squeeze(1) keeps a 1-D tensor even when only a single node matches.
    train_nid_in_local_graph = local_graph.ndata['train_mask'].nonzero().squeeze(1)

    # Get training node original IDs in the complete graph
    local_training_node_nid = local_graph.ndata['_ID'][train_nid_in_local_graph]
    global_training_node_nid = complete_graph.ndata['train_mask'].nonzero().squeeze(1)
    print(local_training_node_nid)
    print(global_training_node_nid)

    # Get all the local nodes' IDs in the complete graph
    local_node_id = local_graph.ndata['_ID']

    # Build the local mask
    local_mask = torch.zeros(complete_graph.number_of_nodes(), dtype=torch.bool)
    local_mask[local_node_id] = True

    # Build the global sampler
    fanouts = [5, 10, 15]
    global_sampler = GlobalNeighborSampler(fanouts, local_mask)

    # Sample the blocks for the first 20 training nodes
    seeds = global_training_node_nid[:20]

    # NOTE(review): evaluate() unpacks ``sampler.sample(...)`` into
    # ``(blocks, transform_indices)``, while this result is iterated directly as
    # block pairs - confirm what GlobalNeighborSampler.sample returns.
    blocks = global_sampler.sample(complete_graph, seeds)

    # Print the src/dst nodes of the blocks
    for i, block_pair in enumerate(blocks):
        local_block, global_block = block_pair
        print(f"Block {i}")
        print(f"Local block: {local_block.srcdata[dgl.NID]} -> {local_block.dstdata[dgl.NID]}")
        print(f"Global block: {global_block.srcdata[dgl.NID]} -> {global_block.dstdata[dgl.NID]}")
        print()

    # Get first block
    local_block, global_block = blocks[0]

    # Assign node features
    in_feats = 16
    complete_graph.ndata['feat'] = torch.randn(complete_graph.num_nodes(), in_feats)
    # Extract features for the global block and the local block
    global_feat = complete_graph.ndata['feat'][global_block.srcdata[dgl.NID]]
    local_feat = complete_graph.ndata['feat'][local_block.srcdata[dgl.NID]]

    # Initialize the CustomSAGEConv layer
    out_feats = 32
    activation = torch.nn.functional.relu
    conv = CustomSAGEConv(in_feats, out_feats, activation)
    print("CustomSAGEConv:", conv)

    # Perform forward pass. CustomSAGEConv.forward returns only ``h``, so the
    # node IDs of its rows are rebuilt here in the same order as the rows of
    # ``h``: local destination nodes first, then the global (remote) ones.
    h = conv(local_block, global_block, local_feat, global_feat)
    activation_nid = torch.cat(
        [local_block.dstdata[dgl.NID], global_block.dstdata[dgl.NID]], dim=0
    )
    print(h)
    print(activation_nid)

    assert h.shape[0] == activation_nid.shape[0]

    # Check the input nid to the next layer
    local_block, global_block = blocks[1]
    local_src_nid = local_block.srcdata[dgl.NID]
    global_src_nid = global_block.srcdata[dgl.NID]
    print("Local src nid:", local_src_nid)
    print("Global src nid:", global_src_nid)
    assert activation_nid.shape[0] == len(local_src_nid) + len(global_src_nid)


if __name__ == '__main__':
    # unittest.main()

    # Write a test for the global sampler
    _run_global_sampler_demo()



    